{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import metadsl\n",
    "import metadsl.python.pure as py_pure\n",
    "import metadsl.python.compat as py_compat\n",
    "import metadsl.numpy.pure as np_pure\n",
    "import metadsl.numpy.compat as np_compat\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Replacing adding zero"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "@metadsl.expression\n",
    "def named_array(s: str) -> np_pure.NDArray:\n",
    "    ...\n",
    "\n",
    "\n",
    "@metadsl.expression\n",
    "def scalar_array(i: int) -> np_pure.NDArray:\n",
    "    ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NDArray(scalar_array, (0,))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "zero = scalar_array(0)\n",
    "zero"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NDArray(named_array, ('some_array',))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n = named_array('some_array')\n",
    "n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Integer(_expression=Integer(Tuple.__getitem__, (metadsl.python.pure.Tuple[metadsl.python.pure.Integer](NDArray.shape, (NDArray(named_array, ('some_array',)),)), Integer(Integer.from_int, (10,)))))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np_compat.NDArray(n).shape()[10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "@metadsl.rule\n",
    "def replace_add_zero(a: np_pure.NDArray):\n",
    "    return (\n",
    "        a + scalar_array(0),\n",
    "        a\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NDArray(NDArray.__add__, (NDArray(named_array, ('some_array',)), NDArray(scalar_array, (0,))))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n + zero"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NDArray(named_array, ('some_array',))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "replace_add_zero(n + zero)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## arange and indexing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define `a` before any cell displays it, so the notebook survives Restart & Run All.\n",
    "# The two trailing None arguments leave the optional Number and DType arguments unset.\n",
    "a = np_compat.arange(\n",
    "    1,\n",
    "    10,\n",
    "    None,\n",
    "    None\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NDArray(_expression=NDArray(arange, (metadsl.python.pure.Optional[metadsl.python.pure.Number](Optional.some, (Number(Number.from_number, (1,)),)), Number(Number.from_number, (10,)), metadsl.python.pure.Optional[metadsl.python.pure.Number](Optional.none_expr, ()), metadsl.python.pure.Optional[metadsl.numpy.pure.DType](Optional.none_expr, ()))))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Integer(_expression=Integer(Tuple.__getitem__, (metadsl.python.pure.Tuple[metadsl.python.pure.Integer](NDArray.shape, (NDArray(arange, (metadsl.python.pure.Optional[metadsl.python.pure.Number](Optional.some, (Number(Number.from_number, (1,)),)), Number(Number.from_number, (10,)), metadsl.python.pure.Optional[metadsl.python.pure.Number](Optional.none_expr, ()), metadsl.python.pure.Optional[metadsl.numpy.pure.DType](Optional.none_expr, ()))),)), Integer(Integer.from_int, (0,)))))"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.shape()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NDArray(_expression=NDArray(NDArray.__add__, (NDArray(arange, (metadsl.python.pure.Optional[metadsl.python.pure.Number](Optional.some, (Number(Number.from_number, (1,)),)), Number(Number.from_number, (10,)), metadsl.python.pure.Optional[metadsl.python.pure.Number](Optional.none_expr, ()), metadsl.python.pure.Optional[metadsl.numpy.pure.DType](Optional.none_expr, ()))), NDArray(arange, (metadsl.python.pure.Optional[metadsl.python.pure.Number](Optional.some, (Number(Number.from_number, (1,)),)), Number(Number.from_number, (10,)), metadsl.python.pure.Optional[metadsl.python.pure.Number](Optional.none_expr, ()), metadsl.python.pure.Optional[metadsl.numpy.pure.DType](Optional.none_expr, ()))))))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a + a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NDArray(_expression=NDArray(NDArray.__getitem__, (NDArray(NDArray.__getitem__, (NDArray(arange, (metadsl.python.pure.Optional[metadsl.python.pure.Number](Optional.some, (Number(Number.from_number, (1,)),)), Number(Number.from_number, (10,)), metadsl.python.pure.Optional[metadsl.python.pure.Number](Optional.none_expr, ()), metadsl.python.pure.Optional[metadsl.numpy.pure.DType](Optional.none_expr, ()))), metadsl.python.pure.Union[metadsl.python.pure.Integer, metadsl.python.pure.Tuple[metadsl.python.pure.Integer]](Union.left_expr, (Integer(Integer.from_int, (10,)),)))), metadsl.python.pure.Union[metadsl.python.pure.Integer, metadsl.python.pure.Tuple[metadsl.python.pure.Integer]](Union.left_expr, (Integer(Integer.from_int, (10,)),)))))"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idxed_twice = a[10][10]\n",
    "idxed_twice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NDArray(_expression=NDArray(NDArray.__getitem__, (NDArray(arange, (metadsl.python.pure.Optional[metadsl.python.pure.Number](Optional.some, (Number(Number.from_number, (1,)),)), Number(Number.from_number, (10,)), metadsl.python.pure.Optional[metadsl.python.pure.Number](Optional.none_expr, ()), metadsl.python.pure.Optional[metadsl.numpy.pure.DType](Optional.none_expr, ()))), metadsl.python.pure.Union[metadsl.python.pure.Integer, metadsl.python.pure.Tuple[metadsl.python.pure.Integer]](Union.right_expr, (metadsl.python.pure.Tuple[metadsl.python.pure.Integer](Tuple.from_items, (Integer(Integer.from_int, (10,)), Integer(Integer.from_int, (10,)))),)))))"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idxed_once = a[10, 10]\n",
    "idxed_once"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Merging array indexing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import metadsl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "@metadsl.rule\n",
    "def getitem_condense(a: np_pure.NDArray, left_idx: py_pure.Integer, right_idx: py_pure.Integer):\n",
    "    return (\n",
    "        a[\n",
    "            py_pure.Union.left(\n",
    "                py_pure.Integer, py_pure.Tuple[py_pure.Integer], left_idx\n",
    "            )\n",
    "        ][\n",
    "            py_pure.Union.left(\n",
    "                py_pure.Integer, py_pure.Tuple[py_pure.Integer], right_idx\n",
    "            )\n",
    "        ],\n",
    "        a[\n",
    "            py_pure.Union.right(\n",
    "                py_pure.Integer,\n",
    "                py_pure.Tuple[py_pure.Integer],\n",
    "                py_pure.Tuple.from_items(py_pure.Integer, left_idx, right_idx),\n",
    "            )\n",
    "        ],\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NDArray(NDArray.__getitem__, (NDArray(arange, (metadsl.python.pure.Optional[metadsl.python.pure.Number](Optional.some, (Number(Number.from_number, (1,)),)), Number(Number.from_number, (10,)), metadsl.python.pure.Optional[metadsl.python.pure.Number](Optional.none_expr, ()), metadsl.python.pure.Optional[metadsl.numpy.pure.DType](Optional.none_expr, ()))), metadsl.python.pure.Union[metadsl.python.pure.Integer, metadsl.python.pure.Tuple[metadsl.python.pure.Integer]](Union.right_expr, (metadsl.python.pure.Tuple[metadsl.python.pure.Integer](Tuple.from_items, (Integer(Integer.from_int, (10,)), Integer(Integer.from_int, (10,)))),))))"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "getitem_condense(idxed_twice._expression)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inferring shapes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "@metadsl.expression\n",
    "def array_with_shape(shape: py_pure.Tuple[py_pure.Integer], a: np_pure.NDArray) -> np_pure.NDArray:\n",
    "    ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "@metadsl.expression\n",
    "def symbolic_shape(s: str) -> py_pure.Tuple[py_pure.Integer]:\n",
    "    ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "metadsl.python.pure.Tuple[metadsl.python.pure.Integer](symbolic_shape, ('some shape',))"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "shape = symbolic_shape(\"some shape\")\n",
    "shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NDArray(array_with_shape, (metadsl.python.pure.Tuple[metadsl.python.pure.Integer](symbolic_shape, ('some shape',)), NDArray(named_array, ('some_array',))))"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a_w_s = array_with_shape(shape, n)\n",
    "a_w_s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "a_w_s_idxed = a_w_s[\n",
    "    py_pure.Union.left(\n",
    "        py_pure.Integer,\n",
    "        py_pure.Tuple[py_pure.Integer],\n",
    "        py_pure.Integer.from_int(0),\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "@metadsl.pure_rule\n",
    "def int_indexing_shape(a: np_pure.NDArray, shape: py_pure.Tuple[py_pure.Integer], idx: py_pure.Integer):\n",
    "    \"\"\"\n",
    "    indexing by an integer removes the outer shape dimension\n",
    "    \"\"\"\n",
    "    return (\n",
    "        array_with_shape(shape, a)[\n",
    "            py_pure.Union.left(\n",
    "                py_pure.Integer, py_pure.Tuple[py_pure.Integer], idx\n",
    "            )\n",
    "        ],\n",
    "        array_with_shape(\n",
    "            shape.rest(),\n",
    "            a[\n",
    "                py_pure.Union.left(\n",
    "                    py_pure.Integer, py_pure.Tuple[py_pure.Integer], idx\n",
    "                )\n",
    "            ],\n",
    "        )\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NDArray.__getitem__(array_with_shape(symbolic_shape(some shape), named_array(some_array)), Union.left_expr(Integer.from_int(0)))\n"
     ]
    }
   ],
   "source": [
    "print(str(a_w_s_idxed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "array_with_shape(Tuple.rest(symbolic_shape(some shape)), NDArray.__getitem__(named_array(some_array), Union.left_expr(Integer.from_int(0))))\n"
     ]
    }
   ],
   "source": [
    "print(str(int_indexing_shape(a_w_s_idxed)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
